package net.gazhi.delonix.core.web;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

/**
 * Re-decodes URI (query-string) request parameters so that non-ASCII GET parameters are not
 * garbled.<br>
 * Code adapted from <a href="http://www.oschina.net/code/snippet_12_2"
 * target="_blank">"A very important class in OSChina: RequestContext"</a><br>
 * <br>
 * Every parameter value is turned back into bytes using the container encoding
 * ({@code originalEncoding}, default {@code 8859_1}) and decoded again using the actual
 * encoding ({@code targetEncoding}, default {@code UTF-8}).<br>
 * <br>
 * Do NOT use this filter if the Connector in Tomcat's {@code server.xml} is configured with
 * {@code URIEncoding="UTF-8"} or {@code useBodyEncodingForURI=true}: the already-correct values
 * would be transcoded again and corrupted.<br>
 * <br>
 * Setting {@code originalEncoding} and {@code targetEncoding} to the same value (for example both
 * to {@code UTF-8}) also turns this filter into a pass-through.
 *
 * @see <a href="http://www.oschina.net/code/snippet_12_2">http://www.oschina.net/code/snippet_12_2</a>
 *
 * @author Jeffrey Lin
 *
 */
public class UriParamEncodingFilter extends OncePerRequestFilter {

	/**
	 * Whether the application runs on a Resin release newer than 3.0. Such releases are treated as
	 * already decoding URI parameters correctly, so requests are passed through untouched.
	 */
	private final static boolean isResinGtV3 = checkResinVersion();

	/**
	 * Detects the Resin version by reflectively reading {@code com.caucho.Version.VERSION}.
	 *
	 * @return {@code true} only when that class is present and reports a version newer than 3.0;
	 *         {@code false} when not running on Resin or the version cannot be read
	 */
	private static boolean checkResinVersion() {
		try {
			Class<?> verClass = Class.forName("com.caucho.Version");
			// VERSION is read as a static field, so no target instance is needed.
			Object ver = verClass.getDeclaredField("VERSION").get(null);
			return ver instanceof String && isVersionAbove30((String) ver);
		} catch (Exception e) {
			// Best effort: not running on Resin, or the field is missing / inaccessible.
			return false;
		} catch (LinkageError e) {
			// The version class exists but could not be loaded or initialized.
			return false;
		}
	}

	/**
	 * Returns whether a dotted version string such as {@code "3.1.9"} denotes a release newer than
	 * 3.0. Only the leading {@code major[.minor]} numbers are inspected, so extra components or
	 * suffixes (e.g. {@code "4.0.36-snapshot"}) are tolerated and a missing minor counts as 0.
	 *
	 * @param version the version string; must not be {@code null}
	 * @return {@code true} if the version is greater than 3.0
	 */
	private static boolean isVersionAbove30(String version) {
		// Compiled locally: this runs once, during class initialization.
		Matcher m = Pattern.compile("^\\s*(\\d+)(?:\\.(\\d+))?").matcher(version);
		if (!m.find()) {
			return false;
		}
		int major = Integer.parseInt(m.group(1));
		int minor = m.group(2) == null ? 0 : Integer.parseInt(m.group(2));
		return major > 3 || (major == 3 && minor > 0);
	}

	/**
	 * Encoding the servlet container used to decode URI parameters (default {@code 8859_1}, Java's
	 * historical name for ISO-8859-1).
	 */
	private String containerEncoding = "8859_1";

	/**
	 * Encoding the URI parameters are assumed to really be in (default {@code UTF-8}).
	 */
	private String actualEncoding = "UTF-8";

	/**
	 * Sets the encoding the container used to decode URI parameters. An empty or {@code null}
	 * value disables transcoding.
	 *
	 * @param originalEncoding a charset name such as {@code ISO-8859-1}
	 */
	public void setOriginalEncoding(String originalEncoding) {
		this.containerEncoding = originalEncoding;
	}

	/**
	 * Sets the encoding the parameters should be re-decoded with. An empty or {@code null} value
	 * disables transcoding.
	 *
	 * @param targetEncoding a charset name such as {@code UTF-8}
	 */
	public void setTargetEncoding(String targetEncoding) {
		this.actualEncoding = targetEncoding;
	}

	/**
	 * Wraps the request in a {@link UriParamEncodingRequest} when transcoding applies; otherwise
	 * forwards the original request unchanged.
	 */
	@Override
	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
		HttpServletRequest req = needsTranscoding(request) ? new UriParamEncodingRequest(request, containerEncoding, actualEncoding) : request;
		filterChain.doFilter(req, response);
	}

	/**
	 * Decides whether the parameters of {@code request} should be transcoded.
	 */
	private boolean needsTranscoding(HttpServletRequest request) {
		// POST requests are left alone (presumably because their form parameters come from the
		// body, which is decoded with the request character encoding rather than the URI encoding).
		if ("POST".equalsIgnoreCase(request.getMethod()) || isResinGtV3) {
			return false;
		}
		if (StringUtils.isEmpty(this.containerEncoding) || StringUtils.isEmpty(this.actualEncoding)) {
			return false;
		}
		// Charset names are case-insensitive: "utf-8" and "UTF-8" are the same encoding.
		return !containerEncoding.equalsIgnoreCase(actualEncoding);
	}

	/**
	 * Request wrapper that transcodes parameter values on every read.
	 *
	 * @author Jeffrey Lin
	 *
	 */
	public static class UriParamEncodingRequest extends HttpServletRequestWrapper {

		/** Encoding the container used to decode the parameters. */
		private final String originalEncoding;

		/** Encoding the parameter values are re-decoded with. */
		private final String targetEncoding;

		/**
		 * Private constructor; instances are created only by the enclosing filter.
		 *
		 * @param request the request to wrap
		 * @param originalEncoding encoding the container used to decode the parameters
		 * @param targetEncoding encoding to re-decode the parameter values with
		 */
		private UriParamEncodingRequest(HttpServletRequest request, String originalEncoding, String targetEncoding) {
			super(request);
			this.originalEncoding = originalEncoding;
			this.targetEncoding = targetEncoding;
		}

		/**
		 * Returns the transcoded value of the named parameter, or {@code null} if absent.
		 */
		@Override
		public String getParameter(String paramName) {
			return decodeParamValue(super.getParameter(paramName));
		}

		/**
		 * Returns a new map containing transcoded copies of all parameter values. {@code String[]}
		 * and {@code String} values are transcoded; {@code null} or other values are copied as-is.
		 * The wrapped request's map and arrays are not modified.
		 */
		@Override
		public Map<String, Object> getParameterMap() {
			Map<?, ?> params = super.getParameterMap();
			Map<String, Object> newParams = new HashMap<String, Object>();
			for (Map.Entry<?, ?> entry : params.entrySet()) {
				String key = (String) entry.getKey();
				Object value = entry.getValue();
				if (value instanceof String[]) {
					newParams.put(key, decodeParamValues((String[]) value));
				} else if (value instanceof String) {
					newParams.put(key, decodeParamValue((String) value));
				} else {
					newParams.put(key, value);
				}
			}
			return newParams;
		}

		/**
		 * Returns transcoded copies of the named parameter's values, or {@code null} if absent.
		 * The array returned by the wrapped request is never modified, so values are not
		 * transcoded a second time if the container hands out the same array on later calls.
		 */
		@Override
		public String[] getParameterValues(String key) {
			String[] values = super.getParameterValues(key);
			return values == null ? null : decodeParamValues(values);
		}

		/**
		 * Transcodes every element of {@code values} into a newly allocated array.
		 */
		private String[] decodeParamValues(String[] values) {
			String[] decoded = new String[values.length];
			for (int i = 0; i < values.length; i++) {
				decoded[i] = decodeParamValue(values[i]);
			}
			return decoded;
		}

		/**
		 * Re-encodes {@code value} with the original encoding and decodes the bytes with the
		 * target encoding.
		 *
		 * @param value the value as decoded by the container; may be {@code null} or empty
		 * @return the transcoded value; {@code null}/empty input is returned unchanged
		 * @throws IllegalStateException if either configured encoding is not supported
		 */
		private String decodeParamValue(String value) {
			if (StringUtils.isEmpty(value)) {
				return value;
			}
			try {
				return new String(value.getBytes(this.originalEncoding), this.targetEncoding);
			} catch (UnsupportedEncodingException e) {
				throw new IllegalStateException("Unsupported encoding for URI parameter transcoding: " + this.originalEncoding + " -> " + this.targetEncoding, e);
			}
		}
	}

}